{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Confusion Matrix\n",
    "Computes the confusion matrix for a classification problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 2.25495,
     "end_time": "2021-01-22T16:49:05.155095",
     "exception": false,
     "start_time": "2021-01-22T16:49:02.900145",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip3 install tensorflow==2.4.0 scikit-learn==0.24.1 wget==3.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch claimed_utils.py, the helper module that `unzip` is imported from.\n",
    "import wget\n",
    "# NOTE(review): this pulls executable code from the moving `master` branch;\n",
    "# consider pinning the URL to a commit hash for reproducible runs.\n",
    "wget.download(\n",
    "    'https://raw.githubusercontent.com/'\n",
    "    'elyra-ai/component-library/master/claimed_utils.py'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from claimed_utils import unzip\n",
    "import os.path\n",
    "import glob\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn import metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @dependency codait_utils.ipynb\n",
    "# @dependency model.zip\n",
    "# @dependency data.zip\n",
    "# @param model zip file name\n",
    "# @param data zip file name\n",
    "# @returns confusion matrix inline in jupyter\n",
    "# (next version will return confusion matrix as csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Archive file names; each can be overridden via an environment variable\n",
    "# of the same name and otherwise falls back to the default shown.\n",
    "model_zip = os.getenv('model_zip', 'model.zip')\n",
    "data_zip = os.getenv('data_zip', 'data.zip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack both archives; presumably '.' is the extraction target directory\n",
    "# (`unzip` comes from claimed_utils — verify its argument order there).\n",
    "unzip('.', model_zip)\n",
    "unzip('.', data_zip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 6.744013,
     "end_time": "2021-01-22T16:49:13.607761",
     "exception": false,
     "start_time": "2021-01-22T16:49:06.863748",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the saved Keras model from the local path 'model'\n",
    "# (assumed to come from extracting model_zip — TODO confirm).\n",
    "model = keras.models.load_model('model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.191976,
     "end_time": "2021-01-22T16:49:13.806114",
     "exception": false,
     "start_time": "2021-01-22T16:49:13.614138",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "folder = glob.glob(\"data/*\")\n",
    "\n",
    "batch_size = 32\n",
    "img_height = 400\n",
    "img_width = 400\n",
    "input_shape = (img_width, img_height)\n",
    "\n",
    "\n",
    "def _load_split(subset):\n",
    "    \"\"\"Return the `subset` side of the seeded 80/20 split of 'data'.\"\"\"\n",
    "    return tf.keras.preprocessing.image_dataset_from_directory(\n",
    "        'data',\n",
    "        validation_split=0.2,\n",
    "        subset=subset,\n",
    "        seed=123,\n",
    "        image_size=(img_height, img_width),\n",
    "        batch_size=batch_size)\n",
    "\n",
    "\n",
    "train_ds = _load_split(\"training\")\n",
    "val_ds = _load_split(\"validation\")\n",
    "\n",
    "# Use the classes Keras actually inferred (one per sub-directory) instead\n",
    "# of counting glob matches, which also counts stray files in data/ and\n",
    "# skips hidden directories. Read this before .map(): the mapped dataset\n",
    "# no longer exposes class_names.\n",
    "class_names = val_ds.class_names\n",
    "num_classes = len(class_names)\n",
    "\n",
    "train_ds = train_ds.map(lambda x, y: (x, tf.one_hot(y, depth=num_classes)))\n",
    "val_ds = val_ds.map(lambda x, y: (x, tf.one_hot(y, depth=num_classes)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 7.376575,
     "end_time": "2021-01-22T16:49:21.189231",
     "exception": false,
     "start_time": "2021-01-22T16:49:13.812656",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Model outputs for every validation sample (arg-maxed later on).\n",
    "# NOTE(review): Keras docs say not to pass batch_size when the input is a\n",
    "# tf.data.Dataset (it is already batched) — looks redundant here. Also,\n",
    "# the row order follows this particular pass over val_ds, which is\n",
    "# presumably shuffled (shuffle is not disabled where val_ds is built).\n",
    "classes = model.predict(\n",
    "    val_ds,\n",
    "    batch_size=batch_size\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.0172,
     "end_time": "2021-01-22T16:49:21.213524",
     "exception": false,
     "start_time": "2021-01-22T16:49:21.196324",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sanity check: number of prediction rows returned by the model.\n",
    "len(classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.295849,
     "end_time": "2021-01-22T16:49:21.515991",
     "exception": false,
     "start_time": "2021-01-22T16:49:21.220142",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Collect labels AND predictions in a single pass over val_ds.\n",
    "# image_dataset_from_directory shuffles by default and the dataset is\n",
    "# reshuffled on every iteration, so labels gathered in their own pass\n",
    "# would not line up row-by-row with predictions obtained from a separate\n",
    "# model.predict(val_ds) call. `classes` is therefore recomputed here.\n",
    "label_batches = []\n",
    "prediction_batches = []\n",
    "for features, labels in val_ds:\n",
    "    label_batches.append(labels.numpy())\n",
    "    prediction_batches.append(model.predict_on_batch(features))\n",
    "\n",
    "# Stack once at the end instead of re-copying the array on every batch.\n",
    "y_test = np.vstack(label_batches)\n",
    "classes = np.vstack(prediction_batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.457365,
     "end_time": "2021-01-22T16:49:21.980061",
     "exception": false,
     "start_time": "2021-01-22T16:49:21.522696",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Rows are true classes, columns are predicted classes. Passing `labels`\n",
    "# (one per one-hot column of y_test) keeps the matrix square with one\n",
    "# row/column per class even when a class never occurs in the validation\n",
    "# split or in the predictions, so index i always means class i.\n",
    "matrix = metrics.confusion_matrix(\n",
    "    y_test.argmax(axis=1),\n",
    "    classes.argmax(axis=1),\n",
    "    labels=np.arange(y_test.shape[1])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.011527,
     "end_time": "2021-01-22T16:49:21.998690",
     "exception": false,
     "start_time": "2021-01-22T16:49:21.987163",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Show the matrix as a table: row index = true class, column = predicted.\n",
    "pd.DataFrame(matrix)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 22.024851,
   "end_time": "2021-01-22T16:49:24.278482",
   "environment_variables": {},
   "exception": null,
   "input_path": "/home/jovyan/work/elyra-classification/confusion-matrix.ipynb",
   "output_path": "/home/jovyan/work/elyra-classification/confusion-matrix.ipynb",
   "parameters": {},
   "start_time": "2021-01-22T16:49:02.253631",
   "version": "2.2.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
